from abc import ABC, abstractmethod

import taichi as ti

from .material import Material
from .utils import (
    Vec3,
    Vec4,
    Ray3,
    Ray4,
    Mat3x3,
    Mat4x4,
    Mat3x3Id,
    lorentzBoost,
    IdGenerator,
    Inf,
    deg2rad,
)
from .config import MAX_SHAPE, MAX_SHAPE_ATTRIBUTE, RELATIVISTIC

# Public API of this module: only the ``Shape`` interface is exported;
# ``ShapeUtil`` is the internal per-shape storage struct.
__all__ = [
    "Shape",
]


@ti.dataclass
class ShapeUtil:
    """Low-level storage record for a single shape.

    ``Shape`` keeps one of these per shape in its ``Shape.shapes`` struct
    field. The ``@ti.func`` helpers map positions, direction vectors and rays
    between the global ("ether") frame and the shape's local frame. The
    relativistic vs. Galilean code path is chosen at compile time through
    ``ti.static(RELATIVISTIC)``.
    """

    # Id of the shape's Python class (key into ``Shape.used``).
    typeId: int
    # Material index (``Material.index`` of the assigned material).
    material: int
    # Global-frame position; subtracted first by ``pos2obj``.
    position: Vec3
    # Velocity of the shape in the global frame.
    velocity: Vec3
    # Global -> local spatial rotation; starts as identity and is accumulated
    # by ``Shape.rotateBy``.
    transform: Mat3x3
    # ``lorentzBoost(velocity)``; used by pos2obj/vec2obj when RELATIVISTIC.
    lorentz: Mat4x4
    # ``lorentzBoost(-velocity)``; used by pos2ether/vec2ether when RELATIVISTIC.
    lorentz_inv: Mat4x4
    # Shape-specific parameters; not written in this file, presumably filled
    # in by ``Shape`` subclasses — confirm against the subclasses.
    attribute: ti.types.vector(MAX_SHAPE_ATTRIBUTE, float)

    @ti.func
    def pos2obj(self, pos: Vec4) -> Vec4:
        """Transform a 4-position from the global frame into the shape's local frame.

        ``pos.w`` acts as the time coordinate: the Galilean branch shifts the
        spatial part by ``velocity * pos.w``.
        """
        # Make the position relative to the shape's reference point.
        pos.xyz -= self.position
        if ti.static(RELATIVISTIC):
            # Apply the Lorentz matrix to the full 4-vector.
            pos = self.lorentz @ pos
        else:
            # Galilean transform: remove the displacement accumulated by time w.
            pos.xyz -= self.velocity * pos.w
        # Rotate the spatial part into the shape's orientation.
        pos.xyz = self.transform @ pos.xyz
        return pos

    @ti.func
    def pos2ether(self, pos: Vec4) -> Vec4:
        """Transform a 4-position from the shape's local frame back to the global frame.

        Undoes the steps of ``pos2obj`` in reverse order.
        """
        # Undo the rotation. NOTE(review): in this file ``transform`` is only
        # ever the identity times rotation matrices, so ``.transpose()`` would
        # be a cheaper equivalent — confirm nothing else writes a
        # non-orthogonal transform before changing it.
        pos.xyz = self.transform.inverse() @ pos.xyz
        if ti.static(RELATIVISTIC):
            # lorentzBoost(-velocity), presumably the inverse of ``lorentz``.
            pos = self.lorentz_inv @ pos
        else:
            # Re-add the Galilean displacement (w is untouched on this path).
            pos.xyz += self.velocity * pos.w
        pos.xyz += self.position
        return pos

    @ti.func
    def vec2obj(self, vec: Vec4) -> Vec4:
        """Transform a 4-vector (direction) from the global frame into the local frame.

        Unlike ``pos2obj`` no translation is applied.
        """
        # NOTE(review): in the non-relativistic branch the velocity is not
        # applied, so this equals the linear part of ``pos2obj`` only when
        # ``vec.w == 0``. If ray directions can carry a non-zero w in Galilean
        # mode, ``ray2obj`` transforms origin and direction inconsistently
        # (the linear part would also need ``vec.xyz -= velocity * vec.w``) —
        # confirm the intended model.
        if ti.static(RELATIVISTIC):
            vec = self.lorentz @ vec
        vec.xyz = self.transform @ vec.xyz
        return vec

    @ti.func
    def vec2ether(self, vec: Vec4) -> Vec4:
        """Transform a 4-vector (direction) from the local frame back to the global frame.

        Inverse of ``vec2obj``; the same non-relativistic caveat applies.
        """
        vec.xyz = self.transform.inverse() @ vec.xyz
        if ti.static(RELATIVISTIC):
            vec = self.lorentz_inv @ vec
        return vec

    @ti.func
    def ray2obj(self, ray: Ray4) -> Ray4:
        """Transform a ray into the shape's local frame (origin as a position, direction as a vector)."""
        ray.origin = self.pos2obj(ray.origin)
        ray.direct = self.vec2obj(ray.direct)
        return ray

    @ti.func
    def ray2ether(self, ray: Ray4) -> Ray4:
        """Transform a ray from the shape's local frame back to the global frame."""
        ray.direct = self.vec2ether(ray.direct)
        ray.origin = self.pos2ether(ray.origin)
        return ray


class Shape(ABC):
    """Python-side interface for creating and editing shapes.

    Each instance owns one slot (``self.index``) in the shared Taichi struct
    field ``Shape.shapes``. A concrete subclass is registered in ``Shape.used``
    (typeId -> class) the first time one of its instances is created, and must
    implement :meth:`intersect`.
    """

    # Type id of the concrete subclass; assigned on its first instantiation.
    typeId: int = None
    # Backing storage for every shape, shared by all subclasses.
    shapes: ti.StructField = ShapeUtil.field(shape=MAX_SHAPE)
    # Single-element counter of allocated slots in ``shapes``.
    shapesNum: ti.ScalarField = ti.field(int, 1)
    # Registry of concrete shape classes keyed by typeId (intentionally shared).
    used: "dict[int, type]" = {}
    # Source of fresh type ids.
    getId = IdGenerator()

    def __init__(self) -> None:
        """Allocate a slot in ``Shape.shapes`` and initialise it with defaults.

        The slot receives this class's ``typeId``, an identity ``transform``,
        zero position and velocity, and ``Material.default``.

        Raises:
            MemoryError: if all ``MAX_SHAPE`` slots are already in use.
        """
        if Shape.shapesNum[0] >= MAX_SHAPE:
            raise MemoryError("number of Shapes out of design.")
        cls = type(self)
        # Register each concrete class once. Membership is checked against the
        # registry values rather than ``cls.typeId`` because a subclass of an
        # already-registered shape inherits its parent's ``typeId`` attribute.
        if cls not in self.used.values():
            cls.typeId = next(self.getId)
            self.used[cls.typeId] = cls
        self.index: int = Shape.shapesNum[0]
        Shape.shapesNum[0] += 1
        Shape.shapes[self.index].typeId = self.typeId
        Shape.shapes[self.index].transform = Mat3x3Id
        self.setPosition(0, 0, 0).setVelocity(0, 0, 0).setMaterial(Material.default)

    def setPosition(self, x: float = 0, y: float = 0, z: float = 0):
        """Set the shape's global-frame position; returns ``self`` for chaining."""
        Shape.shapes[self.index].position = Vec3(x, y, z)
        return self

    def setMaterial(self, mat: "Material"):
        """Assign a material (stored by its ``index``); returns ``self`` for chaining."""
        Shape.shapes[self.index].material = mat.index
        return self

    def setVelocity(self, x: float = 0, y: float = 0, z: float = 0):
        """Set the shape's velocity and refresh both Lorentz matrices.

        Returns ``self`` for chaining.
        """
        # NOTE(review): no bound on |v| is checked here; confirm what
        # lorentzBoost requires of its argument.
        velocity = Vec3(x, y, z)
        Shape.shapes[self.index].velocity = velocity
        Shape.shapes[self.index].lorentz = lorentzBoost(velocity)
        Shape.shapes[self.index].lorentz_inv = lorentzBoost(-velocity)
        return self

    def rotateBy(self, axis: Vec3, angle: float, rad: bool = False):
        """Rotate the shape about ``axis`` by ``angle``.

        Args:
            axis: rotation axis; any non-zero 3-vector (normalised here).
            angle: rotation angle, in degrees unless ``rad`` is true.
            rad: interpret ``angle`` as radians.

        Returns:
            ``self``, for chaining.

        Raises:
            ValueError: if ``axis`` has zero length and so cannot be normalised.
        """
        axis = Vec3(axis)
        # Normalising a zero vector divides by zero; fail with a clear message
        # instead of an obscure error or NaNs in ``transform``.
        if axis.norm() == 0:
            raise ValueError("rotation axis must be a non-zero vector.")
        axis = axis.normalized()
        angle = angle if rad else angle * deg2rad
        c, s = ti.cos(angle), ti.sin(-angle)
        # Rodrigues' formula. With s = sin(-angle) and this transposed skew
        # layout, the matrix equals the standard right-handed rotation by
        # +angle about ``axis``. NOTE(review): ``transform`` maps global ->
        # local coordinates (ShapeUtil.pos2obj), so in world space the shape is
        # effectively rotated by -angle — confirm this is the intended direction.
        rotation = axis.outer_product(axis) * (1 - c) + Mat3x3(
            [
                [c, axis.z * s, -axis.y * s],
                [-axis.z * s, c, axis.x * s],
                [axis.y * s, -axis.x * s, c],
            ]
        )
        # Right-multiplying means the new rotation is applied before the
        # previously accumulated ones when mapping global -> local.
        Shape.shapes[self.index].transform @= rotation
        return self

    @abstractmethod
    def intersect(self: int, ray: Ray3, closest: float) -> tuple[float, Vec3]:
        """Intersect a ray with a shape of this type.

        Called from :meth:`handleAllIntersect` as ``cls.intersect(index, ray,
        closest)``: ``self`` is the shape's slot index in ``Shape.shapes``, not
        an instance, and implementations must be callable from Taichi scope.
        ``closest`` is presumably the nearest hit distance found so far.

        Returns:
            ``(dist, norm)`` — the hit distance and (presumably) the surface
            normal; ``handleAllIntersect`` uses ``Inf`` as its no-hit default.
        """

    @ti.func
    def handleAllIntersect(
        self: int, rayObj: Ray3, closest: float
    ):
        """Dispatch an intersection to the ``intersect`` of this slot's shape class.

        ``self`` is a slot index into ``Shape.shapes``. The loop over registered
        classes is unrolled at compile time by ``ti.static``, so only classes
        registered before the calling kernel is compiled are considered.
        Returns ``(Inf, Vec3(0))`` when no registered class matches the typeId.
        """
        typeId = Shape.shapes[self].typeId
        dist, norm = Inf, Vec3(0)
        for shape in ti.static(tuple(Shape.used.values())):
            if typeId == shape.typeId:
                dist, norm = shape.intersect(self, rayObj, closest)
        return dist, norm
